# datasets/dataset.py

import torch
from torch.utils.data import Dataset
import joblib


class TicketDataset(Dataset):
    """In-memory dataset of ``(features, target)`` pairs loaded via joblib.

    The file is deserialized with ``joblib.load`` and indexed with the keys
    ``"features"`` and ``"labels"``. Both values are converted to ``float32``
    tensors once, at construction time. The label tensor additionally gets a
    new axis inserted at position 1 (``unsqueeze(1)``).

    Attributes are documented here rather than inline:
        features: ``float32`` tensor built from ``data["features"]``.
        targets: ``float32`` tensor built from ``data["labels"]`` with an
            extra axis at position 1.
    """

    def __init__(self, pkl_path):
        """Load *pkl_path* and build the feature/target tensors.

        Args:
            pkl_path: Path (or file object) accepted by ``joblib.load``.

        Raises:
            KeyError: If the loaded mapping lacks ``"features"`` or
                ``"labels"``.
            ValueError: If features and labels do not contain the same
                number of samples.
        """
        super().__init__()
        # SECURITY: joblib.load is pickle-based and can execute arbitrary
        # code while deserializing. Only load files from trusted sources.
        data = joblib.load(pkl_path)
        self.features, self.targets = self._build_tensors(
            data["features"], data["labels"]
        )

    @staticmethod
    def _build_tensors(features, labels):
        """Convert raw features/labels to float32 tensors and validate them.

        Returns:
            A ``(features, targets)`` tuple of tensors, where ``targets`` has
            an extra axis at position 1.

        Raises:
            ValueError: If features is a scalar or its first dimension does
                not match the number of labels.
        """
        feature_tensor = torch.tensor(features, dtype=torch.float32)
        target_tensor = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)

        # __len__ is driven by the targets, so a mismatch would otherwise
        # either silently ignore surplus feature rows or surface much later
        # as an IndexError during iteration. Fail fast instead.
        if feature_tensor.dim() == 0:
            raise ValueError("features must have at least one dimension")
        n_features = feature_tensor.shape[0]
        n_targets = target_tensor.shape[0]
        if n_features != n_targets:
            raise ValueError(
                f"features and labels disagree on sample count: "
                f"{n_features} != {n_targets}"
            )
        return feature_tensor, target_tensor

    def __len__(self) -> int:
        """Return the number of samples (length of the target tensor)."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return the ``(features[idx], targets[idx])`` pair for *idx*."""
        return self.features[idx], self.targets[idx]
